Fix zero prompt-cache reuse for thinking models in multi-turn chat#1042

Open
lyonsno wants to merge 8 commits into ml-explore:main from lyonsno:b1-replay-upstream-main
Conversation

@lyonsno
Contributor

@lyonsno lyonsno commented Mar 22, 2026

The problem

On main, thinking models get zero prompt-cache reuse in multi-turn
chat. This affects any model whose cache layers include non-trimmable types.
In practice, this means Qwen 3.5, Step 3.5 Flash, GPT-OSS, and similar
models are stuck at 0% cache reuse in multi-turn chat.

The root cause: when a thinking model responds, the server caches
prompt + <think>...</think> + response as the key. On the next turn, the
chat template strips the thinking content, so the new prompt diverges from
the cached key at the first thinking token. The server finds the longer
cache and tries can_trim_prompt_cache — but that requires all layers
to be trimmable, and ArraysCache returns is_trimmable=False. Result:
the longer cache can't be trimmed, there's no shorter match, and the entire
prompt is re-prefilled from scratch.
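A minimal sketch of the divergence (token strings are made up for illustration; real caches key on token ids):

```python
# The server caches prompt + <think>...</think> + response as the key.
# The next turn's template strips the thinking content, so the common
# prefix with the cached key ends at the thinking-start token.

def common_prefix_len(a, b):
    """Length of the shared prefix of two token sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

cached_key  = ["<user>", "hi", "<think>", "plan", "</think>", "hello"]
next_prompt = ["<user>", "hi", "hello", "<user>", "thanks"]

print(common_prefix_len(cached_key, next_prompt))  # 2: diverges at <think>
```

With a non-trimmable layer in the cache, that longer cached key can't be cut back to the 2-token match, so the whole prompt is recomputed.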

Measured impact (Qwen 3.5 35B-A3B BF16, 40-turn conversation, 64 tok/turn):

| | upstream/main | This PR |
|---|---|---|
| Turn 10 cache hit | 0% | 88.7% |
| Turn 40 cache hit | 0% | 97.3% |
| Total tokens skipped (turns 2–40) | 0 / 64,434 | 61,116 / 64,434 |

What this PR does

This PR restores prompt-cache reuse in cases that currently degenerate to full prefill for thinking and hybrid-attention models.

1. Checkpoint caching — After processing a chat turn whose last message
is from the user, the server saves a KV cache checkpoint just before the
thinking-start token. On the next turn, the checkpoint matches the prompt
prefix (before thinking was stripped), enabling reuse regardless of cache
layer types. Checkpoints are stored in the existing LRUPromptCache and
evicted under normal LRU pressure — no additional unbounded state is
introduced.
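The placement can be sketched like this (names and the token id are hypothetical, not the PR's actual API):

```python
# Sketch of checkpoint placement: save the KV state for the prefix that
# precedes the first thinking-start token, since that prefix is exactly
# what survives into the next turn after <think> content is stripped.

THINK_START = 151667  # assumed token id for the thinking-start token

def checkpoint_position(tokens):
    """Return the index just before the first thinking-start token,
    or None if the sequence contains no thinking content."""
    for i, tok in enumerate(tokens):
        if tok == THINK_START:
            return i  # checkpoint covers tokens[:i]
    return None

tokens = [10, 20, 30, THINK_START, 40, 50]
print(checkpoint_position(tokens))  # 3
```

Because the checkpoint is a plain prefix snapshot, no trimming or rewinding of the stored layers is ever needed to reuse it, which is why it works for non-trimmable cache types.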

2. Rewind-based longer-cache reuse — fetch_nearest_cache can now
rewind a longer cached prompt to match a shorter request using per-layer
can_rewind/rewind instead of the all-or-nothing can_trim_prompt_cache.
This enables cache reuse for BatchRotatingKVCache (sliding-window
attention) and falls back gracefully through the legacy is_trimmable/trim
contract for other cache types. Rewind is only applied when per-layer
invariants guarantee equivalence with prefix truncation; otherwise we fall
back to a cache miss. A fail-closed precheck avoids deepcopy on guaranteed
misses.
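The fail-closed, per-layer shape of that path can be sketched as follows (simplified stand-in classes, not the PR's real cache types):

```python
import copy

# Illustrative sketch: rewind a deepcopy of each layer, preferring the
# per-layer can_rewind/rewind contract and falling back to the legacy
# is_trimmable()/trim() contract; fail closed (cache miss) on any layer
# supporting neither, leaving the stored originals intact.

class TrimOnlyCache:
    """Legacy contract only: is_trimmable()/trim()."""
    def __init__(self, tokens):
        self.tokens = list(tokens)
    def is_trimmable(self):
        return True
    def trim(self, n):
        del self.tokens[len(self.tokens) - n:]

class RewindableCache:
    """New per-layer contract: can_rewind()/rewind()."""
    def __init__(self, tokens):
        self.tokens = list(tokens)
    def can_rewind(self, n):
        return n <= len(self.tokens)
    def rewind(self, n):
        del self.tokens[len(self.tokens) - n:]

def rewind_layers(layers, n):
    copies = copy.deepcopy(layers)
    for layer in copies:
        if hasattr(layer, "can_rewind") and layer.can_rewind(n):
            layer.rewind(n)
        elif getattr(layer, "is_trimmable", lambda: False)():
            layer.trim(n)
        else:
            return None  # unsupported layer: treat as a cache miss
    return copies
```

Working on a deepcopy is what makes the failure path safe: a partially rewound copy is simply discarded and the original entry is untouched.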

3. Refcounted extraction — Cache entries track an insertion count.
Checkpoint entries are persistent (extraction deepcopies, original stays),
so a checkpoint survives across turns until evicted by LRU pressure.
Regular entries are consumed on last reference. This prevents handing out
the same mutable cache to concurrent requests.
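The extraction rule can be sketched like this (field and function names are illustrative, not the PR's implementation):

```python
import copy

# Sketch of refcounted extraction: persistent checkpoint entries are
# deep-copied out so the stored original survives across turns; regular
# entries are handed out directly and removed on their last reference,
# so no two requests ever share the same mutable cache object.

class CacheEntry:
    def __init__(self, cache, persistent=False):
        self.cache = cache
        self.persistent = persistent
        self.ref_count = 1

def extract(store, key):
    entry = store[key]
    if entry.persistent:
        return copy.deepcopy(entry.cache)  # original stays resident
    entry.ref_count -= 1
    if entry.ref_count == 0:
        del store[key]                     # consumed on last reference
        return entry.cache                 # safe to hand out uncopied
    return copy.deepcopy(entry.cache)      # other references remain
```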

4. Batch checkpoint compatibility — BatchGenerator.insert() accepts
per-prompt checkpoint positions. Prompts whose checkpoint tails are
incompatible are truncated to the largest compatible prefix rather than
silently generating invalid checkpoints.

Benchmark results

All benchmarks: MacBook Pro 16" M4 Max (40-core GPU), 128 GB unified memory.

40-turn multi-turn chat (Qwen 3.5 35B-A3B BF16, 64 tokens generated
per turn, short user messages, LRUPromptCache default max_size=10):

| Turn | Prompt tokens | This PR hit% | upstream hit% | This PR cache | upstream cache |
|---|---|---|---|---|---|
| 1 | 21 | 0% | 0% | 139 MB | 70 MB |
| 5 | 344 | 75.6% | 0% | 716 MB | 359 MB |
| 10 | 750 | 88.7% | 0% | 791 MB | 759 MB |
| 15 | 1,161 | 92.9% | 0% | 879 MB | 843 MB |
| 20 | 1,567 | 94.7% | 0% | 967 MB | 932 MB |
| 30 | 2,389 | 96.5% | 0% | 1,134 MB | 1,095 MB |
| 40 | 3,201 | 97.3% | 0% | 1,298 MB | 1,263 MB |

Overall: this PR skips 61,116 of 64,434 prompt tokens (94.9%). Upstream
skips 0.

Wall-clock timing (same 40-turn run, measured prefill and generation
separately):

| | This PR | upstream/main |
|---|---|---|
| Prefill | 10.2s | 40.9s |
| Generation | 48.9s | 48.6s |
| Total | 59.1s | 89.5s |

4x prefill speedup, 34% reduction in total time. Generation speed is
unaffected by checkpoint caching.

Multi-model comparison (6 turns, 256 tokens generated per turn):

| Model | Architecture | This PR (turn 6) | upstream/main (turn 6) |
|---|---|---|---|
| Qwen 0.5B 4-bit | Full attention | 94.8% | 94.1% |
| Qwen 3.5 35B BF16 | Full attention + thinking | 80.4% | 0% |
| Step 3.5 Flash mixed-p | Full + SWA + thinking | 80.5% | 0% |

Branching conversation (shared prefix, diverging follow-ups):

| Model | Architecture | This PR | upstream/main |
|---|---|---|---|
| Step 3.5 Flash mixed-p | Full + SWA + thinking | 52.1% (rewind) | 0% |

For non-thinking full-attention models, cache reuse is unchanged — upstream
already handles this well. The improvement targets thinking models and
hybrid attention architectures, which currently get no reuse at all.

Memory overhead

Both branches hold 10 LRU entries at steady state — upstream already pays
the cost of caching completions it can never reuse. The additional memory
from checkpoint entries is ~35 MB (< 3% overhead at turn 40). Cache
entry count stays flat at 10 from turn 5 onward; memory grows sub-linearly
as older entries are evicted and replaced by slightly larger ones reflecting
the growing conversation context.

Checkpoint creation is gated on tokenizer.has_thinking. Non-thinking
models create no checkpoint entries — verified on Qwen 0.5B over 30 turns
with identical cache entry counts, memory usage, and hit rates vs upstream.

Behavioral changes from upstream

  • LRUPromptCache.insert_cache no longer evicts shorter prefix entries
    when a longer entry is inserted. Both coexist under normal LRU pressure.
    This is required for checkpoint entries which must not be displaced by
    their longer completions.
  • Shorter-prefix fetches now consume the entry when its refcount reaches 1,
    where previously they deepcopied and left the entry in place.
  • _search returns shorter matches at index 0 (single-token prefix).
    Previously required index > 0.
  • _lazy_extract_cache (generator for lazy cache extraction) replaced with
    eager list comprehension to avoid late-binding issues in checkpoint
    callbacks.
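The last item is the standard lazy-evaluation gotcha: a generator defers its reads until it is consumed, so if the underlying caches are mutated before a checkpoint callback drains it, the callback sees the mutated state. A toy reproduction (not the PR's code):

```python
# A generator expression defers element reads; a list comprehension
# snapshots them immediately. If state mutates between creation and
# consumption (as it can when a callback fires later), only the eager
# version captures the intended values.

caches = [{"n": 1}, {"n": 2}]

lazy  = (c["n"] for c in caches)   # reads deferred until consumed
eager = [c["n"] for c in caches]   # reads happen right now

caches[0]["n"] = 99                # mutation before the callback runs

print(list(lazy))   # [99, 2] -- saw the mutated state
print(eager)        # [1, 2]  -- eager snapshot
```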

Size

~330 lines of production changes across 3 files, ~1,800 lines of test
coverage across 6 test files. The test surface includes:

  • Black-box LRUPromptCache behavior tests (rewind ranking, recency
    refresh, extraction lifecycle, partial-rewind safety)
  • BatchRotatingKVCache rewind + mask correctness after rotation
  • BatchGenerator checkpoint callback contract tests with real model
    inference (resume-from-checkpoint reproduces full-prompt logprobs)
  • ResponseGenerator integration tests (checkpoint forwarding/suppression,
    warm-cache localization, malformed-request worker survival)

Tests

pytest tests/test_prompt_cache.py
pytest tests/test_prompt_cache_server_behavior.py
pytest tests/test_prompt_cache_server_rewind_internal.py
pytest tests/test_generate.py
pytest tests/test_server.py

Full suite: 207 passed, 1 skipped, 3 warnings

lyonsno and others added 3 commits March 21, 2026 21:10
Pin behavioral contracts for review findings: checkpoint persistence
through repeated extraction, partial rewind safety on longer hits,
refcount lifecycle, deepcopy failure resilience, single-token shorter
match threshold, prefix non-eviction on longer insert, and checkpoint
localization suppression at prompt boundaries.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lyonsno lyonsno marked this pull request as ready for review March 23, 2026 10:04
@lyonsno lyonsno changed the title Improve prompt-cache checkpoint reuse in server and batch generation Fix prompt-cache reuse for thinking models (and add checkpoint caching) Mar 23, 2026
@lyonsno lyonsno changed the title Fix prompt-cache reuse for thinking models (and add checkpoint caching) Fix zero prompt-cache reuse for thinking models in multi-turn chat Mar 23, 2026
Non-thinking models get no benefit from checkpoint caching (their cache
keys don't diverge between turns), so storing checkpoint entries is pure
memory overhead. Gate checkpoint creation on tokenizer.has_thinking to
eliminate unnecessary cache growth for standard models.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lyonsno
Contributor Author

lyonsno commented Mar 23, 2026

There is partial overlap with #1006, which addresses one slice of this problem space (0% cache reuse for hybrid models with non-trimmable cache types). Both PRs were developed independently in response to #980.

#1006 handles the non-trimmable checkpoint case. This PR also handles thinking-token divergence (where chat templates strip <think> content between turns), rewind-based longer-cache reuse for sliding-window models, refcounted extraction for concurrent request safety, and batch checkpoint compatibility for mixed-length prefill batches. Test coverage is ~1,800 lines across 6 test files, including end-to-end logprob-equivalence checks against a real model.

Benchmark methodology and results are in the PR description. Harness source is at lyonsno/mlx-lm-benchmarks.

Non-thinking models with non-trimmable caches (ArraysCache) need the
checkpoint entry to enable cache reuse via the shorter-cache path.
The early return for non-thinking models was a regression from upstream
behavior where _compute_prompt_checkpoint always returns (True, -1)
for user-terminal chat requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Thump604

Nice work. We run Qwen3.5-122B with thinking enabled on M2 Ultra 128GB and the 0% cache reuse on multi-turn is one of our top pain points. The checkpoint-at-think-boundary approach is the right call.

Code review notes:

Strengths:

  • Test coverage is excellent. The real-model logprob-equivalence tests with Qwen 0.5B give strong correctness confidence. Mock layers for edge cases (partial rewind failure, legacy fallback) cover the important failure modes.
  • The BatchRotatingKVCache.trim() fix for left_padding on rotated caches looks like a genuine bug fix that should land regardless of the rest of the PR.
  • Fail-closed design on rewind: deepcopy, attempt rewind, discard on failure. Original cache stays intact.

Minor concerns:

  • _can_rewind_layer_cache introspects type(layer_cache).rewind is not _BaseCache.rewind to detect custom overrides. This is fragile if someone subclasses and delegates without overriding. A simpler protocol check (hasattr + explicit supports_rewind flag, or just try/except) might be more robust.
  • The count field on CacheEntry reads like a refcount but is really an insertion counter. A rename to insertion_count or a brief docstring would help readers.
  • insert_cache no longer evicts shorter prefixes. This is correct for checkpoints but increases LRU pressure in non-thinking scenarios. Probably fine in practice since LRU eviction handles it, but worth noting.

MTP compatibility: The changes to cache.py are additive (new methods on BatchRotatingKVCache). Since we use an editable mlx-lm with MTP support (PR #990), these would flow in cleanly. The standard KVCache used by MTP layers has is_trimmable()=True + trim(), so the legacy fallback path handles them correctly.

We can benchmark on the 122B (M2 Ultra 128GB) once this is ready to test. Happy to report TTFT at turn 5/10/20 with thinking mode on Qwen3.5.

@Thump604

Production TTFT data from M2 Ultra 128GB, Qwen3.5-122B-A10B 5-bit, thinking mode enabled.

10-turn coding conversation (progressive refinement of a Python module):

| Turn | TTFT (s) | Content | Reasoning |
|---|---|---|---|
| 1 | 27.83 | 0 | 1706 |
| 2 | 24.32 | 0 | 1904 |
| 3 | 9.72 | 452 | 179 |
| 4 | 17.83 | 810 | 308 |
| 5 | 8.42 | 347 | 168 |
| 6 | 25.77 | 1325 | 462 |
| 7 | 21.88 | 1227 | 266 |
| 8 | 26.21 | 1502 | 251 |
| 9 | 26.42 | 1478 | 380 |
| 10 | 26.54 | 1474 | 296 |

TTFT growth factor: 0.95x (flat, not growing). We're running vllm-mlx's prefix cache, not mlx-lm's LRUPromptCache, so the comparison isn't direct. But it confirms the problem: turns with heavy thinking content (1-2, 6-10) show 25-28s TTFT on 122B. The checkpoint approach in this PR would directly address the thinking-token divergence causing cache misses.

We'll do a proper A/B test with the full PR applied when it's closer to merge. For now the cache.py changes (BatchRotatingKVCache rewind/trim fix) are applied to our editable mlx-lm without issues.

lyonsno and others added 2 commits March 24, 2026 21:57
Metal allocator non-determinism causes the prompt-path subtest to
flake at 1.35x. A real memory leak over 120 steps would be 10x+,
so 2.0x still catches the failure mode without false positives.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Address reviewer feedback from PR ml-explore#1042:

- CacheEntry.count → ref_count: the field is decremented on extraction,
  so it's a reference count, not an insertion counter.

- Add default rewind() on _BaseCache and a _has_rewind_impl() helper
  that uses method identity to detect real overrides. This replaces the
  inline introspection in _can_rewind_layer_cache with a cleaner helper
  while preserving the same behavior: third-party _BaseCache subclasses
  that implement rewind() participate automatically without needing an
  explicit opt-in flag.

- Add targeted tests for the _has_rewind_impl contract covering base
  class, no-override subclass, custom override, and BatchRotatingKVCache.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lyonsno
Contributor Author

lyonsno commented Mar 25, 2026

Thanks for the thoughtful review. I incorporated the two code-level suggestions that seemed worth tightening.

For the rewind-path check, I refactored the _BaseCache contract so it now has a default rewind() plus a _has_rewind_impl() helper that uses method identity to detect real overrides (6070f37). That keeps the existing fail-closed behavior while making the intent clearer, and it means third-party _BaseCache subclasses that implement rewind() participate automatically without needing an extra opt-in flag. I also added targeted tests covering the base class, a subclass without an override, a subclass with a real override, and BatchRotatingKVCache.

I also renamed CacheEntry.count to ref_count. You were right that count was too vague; this field is decremented on extraction, so ref_count is a better description.

On shorter-prefix retention, I left the behavior unchanged. That was a deliberate tradeoff rather than an oversight: checkpoint entries need shorter prefixes to coexist, and LRU eviction already bounds the pressure. If someone later shows a workload where coexisting shorter prefixes measurably hurt hit rate or memory behavior, I'd be happy to revisit that separately.

Good to hear that the fallback path through is_trimmable()/trim() lines up cleanly with your MTP setup. And yes, I'd definitely welcome the 122B A/B when you get a chance. TTFT at turns 5/10/20 with thinking enabled would be especially useful.

@Thump604

Merged this PR into our editable mlx-lm (feat/mtp-native branch, Qwen3.5-122B). The merge was clean -- one trivial conflict in qwen3_5.py (our MTP _process_chunk refactor vs your cache.advance(S) addition, both kept).

15-turn benchmark on M2 Ultra 128GB, Qwen3.5-122B-A10B 5-bit, thinking enabled, vllm-mlx BatchedEngine:

| Turn | TTFT (s) | Content | Reasoning |
|---|---|---|---|
| 1 | 38.1 | 0 | 1720 |
| 5 | 28.4 | 403 | 144 |
| 10 | 45.7 | 1690 | 192 |
| 15 | 24.4 | 23 | 164 |

TTFT growth factor: 0.64x (TTFT not growing across turns). Note: TTFTs elevated vs earlier run due to concurrent load testing sharing the server.

The ref_count rename and _has_rewind_impl refactor look good. The BatchRotatingKVCache.rewind() and trim() left_padding fix are clean and additive -- no issues with our MTP cache layers.

For a proper isolated A/B (with vs without PR, same workload, no concurrent load), we can run that when things quiet down. The data above confirms correctness on 122B with thinking mode.

@Thump604

Clean 15-turn benchmark (solo, no concurrent load). M2 Ultra 128GB, Qwen3.5-122B-A10B 5-bit, thinking enabled, vllm-mlx BatchedEngine:

| Turn | TTFT (s) | Content | Reasoning |
|---|---|---|---|
| 1 | 7.2 | 300 | 175 |
| 2 | 6.4 | 262 | 137 |
| 3 | 12.2 | 654 | 177 |
| 4 | 25.0 | 1288 | 175 |
| 5 | 25.4 | 1715 | 163 |
| 10 | 26.2 | 1739 | 132 |
| 13 | 4.9 | 23 | 215 |
| 15 | 4.9 | 23 | 164 |

TTFT growth factor: 0.68x. TTFT plateaus around 26s for heavy turns (1500+ chars content) and drops to ~5s for short responses. No degradation across turns.

This is through vllm-mlx's prefix cache (not mlx-lm's LRUPromptCache), so the checkpoint mechanism in this PR isn't directly exercised. For a proper A/B with mlx-lm's server, I'll need to temporarily swap servers. Can do that in a follow-up session.

The cache.py changes (rewind/trim) are confirmed working on 122B MTP with thinking enabled. No issues.

@Thump604

PR #1042 A/B Test: mlx-lm server on 122B (as requested)

Here's the proper A/B using mlx-lm's server (not vllm-mlx), exercising the LRUPromptCache checkpoint logic directly.

Setup: M2 Ultra 128GB, Qwen3.5-122B-A10B 5-bit, python -m mlx_lm.server, thinking enabled, 15-turn coding conversation.

| Turn | TTFT (s) | Content (chars) |
|---|---|---|
| 1 | 36.6 | 247 |
| 5 | 3.7 | 1108 |
| 10 | 7.9 | 7047 |
| 15 | 10.4 | 7087 |

TTFT growth factor: 0.28x (sublinear -- prefix cache is working across turns).

The key signal: Turn 1 cold-starts at 36.6s, Turn 5 drops to 3.7s (10x improvement from cache hit), and by Turn 15 TTFT is only 10.4s despite accumulating ~60K chars of conversation context. Without PR #1042, thinking models would get 0% cache reuse and TTFT would grow linearly with conversation length.

Bug found and fixed

During testing, _can_rewind_layer_cache crashed with a NameError on the legacy fallback path (line 347):

```python
# Line 347 - `rewind` was never defined in scope
if not callable(is_trimmable) or (not callable(trim) and not callable(rewind)):
```

Fix: add rewind = getattr(layer_cache, "rewind", None) before line 347. This path is hit when a layer cache doesn't implement can_rewind() but does implement the legacy is_trimmable()/trim()/rewind() contract. With Qwen3.5's MTP cache layers, the can_rewind path is used, but the fallback crashes for any third-party cache that only implements the legacy interface.

@Thump604

Code audit: additional findings

While testing on 122B, I audited the prompt cache and batch paths more carefully. Beyond the NameError fix reported above, here are findings worth addressing:

1. _rewind_layer_cache disagrees with _can_rewind_layer_cache (logic bug)

For cache types with is_trimmable()/trim() but no custom can_rewind/rewind (KVCache, RotatingKVCache, etc.):

  • _can_rewind_layer_cache correctly falls to the legacy path (line 345+) because can_rewind is None. Returns True.
  • _rewind_layer_cache finds rewind is callable (inherited from _BaseCache, raises NotImplementedError). Enters try at line 374, calls rewind(), catches NotImplementedError, returns False. Never reaches its own legacy trim fallback (lines 386-395).

Impact: In fetch_nearest_cache, for any "longer" cached entry with these cache types, _can_rewind says "yes", a copy.deepcopy is made (expensive!), then _rewind returns False, and the deepcopy is discarded. The longer-match rewind optimization is non-functional for all cache types except BatchRotatingKVCache.

Fix: Check _has_rewind_impl() before calling rewind() in _rewind_layer_cache, mirroring the logic in _can_rewind_layer_cache.

2. args.top_logprobs uses wrong request in batch loop (line 1128)

```python
_format_top_logprobs(r.logprobs, args.top_logprobs, current_tokenizer)
```

args is from the last request dequeued (line 979), not the request that produced batch response r. Each request in the batch could have a different top_logprobs value. Should store top_logprobs per-request when inserted into the batch.

3. CacheOrder.pop() crash on empty deques (line 196-199)

When both _lru and _lru_checkpoints are empty, 0 >= 0 is True, and _lru.popleft() raises IndexError. Called from trim_to which has no guard against empty cache. Trigger: trim_to(n_bytes=0) when _n_bytes is out of sync.

4. Legacy fallback in _rewind_layer_cache is dead code (lines 386-395)

Related to #1. The fallback path is unreachable for any _BaseCache subclass because _BaseCache defines rewind() (raises NotImplementedError), making getattr(layer_cache, "rewind", None) always callable. The if callable(rewind): at line 373 always enters the try/except. Only runs for non-_BaseCache custom caches.

These were found during production testing on Qwen3.5-122B. Happy to discuss any of them.

_rewind_layer_cache now checks _has_rewind_impl() before calling
rewind(), matching _can_rewind_layer_cache's logic. Previously,
_BaseCache subclasses with only trim() (KVCache, RotatingKVCache)
would have _can_rewind say yes but _rewind fail via the stub,
wasting a deepcopy. Also adds missing getattr for rewind in the
legacy fallback of _can_rewind_layer_cache.

Six agreement tests added covering legacy non-BaseCache caches,
BaseCache-with-trim-no-rewind, and BatchRotatingKVCache.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@lyonsno
Contributor Author

lyonsno commented Mar 25, 2026

Thanks for both the isolated mlx_lm.server A/B and the follow-up audit. That 0.28x TTFT growth factor on 122B is exactly the kind of signal that helps de-risk this PR, and I really appreciate you taking the time to swap servers and isolate the run.

On the audit findings:

1 & 4 (_rewind_layer_cache / _can_rewind_layer_cache disagreement + dead legacy fallback): Good catches. These are two faces of the same bug. _rewind_layer_cache now checks _has_rewind_impl() before attempting rewind(), so _BaseCache subclasses with only trim() (KVCache, RotatingKVCache) correctly fall through to the legacy path instead of hitting the base stub and wasting a deepcopy. The missing rewind = getattr(...) on the legacy fallback path is also fixed in the same commit. Six agreement tests were added covering legacy non-_BaseCache caches, _BaseCache-with-trim()-but-no-rewind, and BatchRotatingKVCache. Pushed in aa0310b.

2 (top_logprobs batch scoping): Agreed — pre-existing and worth a separate fix.

3 (CacheOrder.pop() on empty deques): Opened separately as #1054 off upstream/main, with no dependency on this PR.

And thanks again for digging deeply enough to catch the legacy fallback NameError.

@Thump604

@angeloskath @awni — this PR has been open since March 18 with no maintainer review. The author (lyonsno) has been very responsive to feedback, addressing all code review findings promptly including bug fixes with new test coverage.

Is there a concern with the approach or scope? Happy to help if anything is needed to move this forward.

Context: thinking models currently get 0% prompt cache reuse on multi-turn conversations because thinking tokens diverge between turns. This PR fixes that with checkpoint-based rewind. We validated on 122B production (0.28x TTFT growth factor across 15 turns) and found + reported 5 bugs during testing, all addressed by the author.

@angeloskath
Member

Sorry for taking long to reply. #1072 should be taking care of this. Big part of it was already taken care of on main but either way feel free to provide a snippet on #1072 that shows the cache not being properly used if you are still experiencing issues.
